"""
This module defines the FSS key class used for the verifiable distributed point function (VDPF), which makes both the key-generation and the evaluation processes of a DPF verifiable.
The implementation is based on the work of Leo de Castro, Antigoni Polychroniadou. Lightweight, Maliciously Secure Verifiable Function Secret Sharing. EUROCRYPT 2022, LNCS 13275, pp. 150–179, 2022.
For reference, see the `paper <https://doi.org/10.1007/978-3-031-06944-4_6>`_.
"""
#  This file is part of the NssMPClib project.
#  Copyright (c) 2024 XDU NSS lab,
#  Licensed under the MIT license. See LICENSE in the project root for license information.

import torch

from NssMPC.common.random import PRG
from NssMPC.common.utils import convert_tensor
from NssMPC.config import HALF_RING, LAMBDA, data_type, BIT_LEN, DEVICE, PRG_TYPE
from NssMPC.crypto.aux_parameter import Parameter
from NssMPC.crypto.aux_parameter.function_secret_sharing_keys import CW, CWList


class VDPFKey(Parameter):
    """
    One party's key for the verifiable distributed point function (VDPF).

    Keys are produced in pairs by :meth:`gen`, one per party. The two keys of a pair carry the same correction
    words, output correction word and verification value; only the root seed ``s`` differs between them.

    ATTRIBUTES:
        * **s** (*torch.Tensor*): This party's root seed, derived from a PRG.
        * **cw_list** (:class:`CWList <NssMPC.crypto.aux_parameter.function_secret_sharing_keys.cw.CWList>`): The
          correction words, one appended for every bit of the input point ``alpha``.
        * **ocw** (*torch.Tensor*): The extra (output) correction word, which encodes ``beta``.
        * **cs** (*torch.Tensor*): The verification value (the XOR of both parties' proofs).
    """

    def __init__(self):
        """
        Build an empty key whose fields are filled in later by the key-generation routine.

        ``s``, ``ocw`` and ``cs`` start out as **None**; ``cw_list`` starts out as a fresh **CWList** object.
        """
        # Assigned left to right, so attributes are set in the order s, cw_list, ocw, cs.
        self.s, self.cw_list, self.ocw, self.cs = None, CWList(), None, None

    @staticmethod
    def gen(num_of_keys, alpha, beta):
        """
        Produce a batch of VDPF key pairs.

        Thin entry point that forwards to the module-level ``vdpf_gen`` routine.

        :param num_of_keys: How many independent keys to produce.
        :type num_of_keys: int
        :param alpha: The point at which the VDPF evaluates to ``beta``.
        :type alpha: RingTensor
        :param beta: The value returned when the input equals ``alpha``.
        :type beta: RingTensor
        :return: The two parties' keys, ``(k0, k1)``.
        :rtype: tuple
        """
        key_pair = vdpf_gen(num_of_keys=num_of_keys, alpha=alpha, beta=beta)
        return key_pair


def vdpf_gen(num_of_keys, alpha, beta):
    """
    Generate keys for VDPF.

    Runs the two-party key generation of the verifiable DPF. It walks the binary tree along the path selected by
    the bits of ``alpha`` (one level per bit) and emits one correction word per level, shared by both keys. It then
    derives the verification value ``cs`` and the output correction word ``ocw``, which encodes ``beta``.
    ``num_of_keys`` key pairs are produced in a single batched pass.

    :param num_of_keys: Number of keys to generate.
    :type num_of_keys: int
    :param alpha: The comparison point of VDPF. Its bits are consumed from index ``alpha.bit_len - 1`` down to
        index ``0`` to choose the path through the tree.
    :type alpha: RingTensor
    :param beta: The output value if the comparison is true.
    :type beta: RingTensor
    :return: The keys of the two involved parties, ``(k0, k1)``, each a :class:`VDPFKey`. Both keys receive the
        same correction-word objects, ``cs`` and ``ocw``; only the root seed ``s`` differs.
    :rtype: tuple
    """
    # Draw one random PRG seed tensor per party, shaped [num_of_keys, LAMBDA // BIT_LEN].
    # NOTE(review): torch.randint's upper bound is exclusive, so the value HALF_RING - 1 is never drawn.
    seed_0 = torch.randint(-HALF_RING, HALF_RING - 1, [num_of_keys, LAMBDA // BIT_LEN], dtype=data_type, device=DEVICE)
    seed_1 = torch.randint(-HALF_RING, HALF_RING - 1, [num_of_keys, LAMBDA // BIT_LEN], dtype=data_type, device=DEVICE)
    # The two lines above generate the seeds for the pseudorandom generator (PRG).

    # Derive each party's root seed from its PRG seed via bit_random_tensor(LAMBDA).
    prg = PRG(PRG_TYPE, device=DEVICE)
    prg.set_seeds(seed_0)
    s_0_0 = prg.bit_random_tensor(LAMBDA)
    prg.set_seeds(seed_1)
    s_0_1 = prg.bit_random_tensor(LAMBDA)

    k0 = VDPFKey()
    k1 = VDPFKey()

    # The root seed is the only field that differs between the two keys.
    k0.s = s_0_0
    k1.s = s_0_1

    # Current seeds on the path to alpha, for party 0 and party 1.
    s_last_0 = s_0_0
    s_last_1 = s_0_1

    # Initial control bits: party 0 starts at 0 and party 1 at 1.
    t_last_0 = 0
    t_last_1 = 1
    # NOTE(review): this rebinds `prg` to a fresh PRG instance, which is then used by CW.gen_dpf_cw and for the
    # final pi_0 / pi_1 derivation. Whether a new instance is required is not visible from here.
    prg = PRG(PRG_TYPE, DEVICE)

    for i in range(alpha.bit_len):
        # Expand each party's current seed into left/right child seeds and control bits.
        s_l_0, t_l_0, s_r_0, t_r_0 = CW.gen_dpf_cw(prg, s_last_0, LAMBDA)
        s_l_1, t_l_1, s_r_1, t_r_1 = CW.gen_dpf_cw(prg, s_last_1, LAMBDA)

        # cond is True where the current alpha bit (index bit_len - 1 - i) is 0, i.e. the path goes left.
        # It is reshaped to a (-1, 1) column so it broadcasts row-wise against the per-key tensors.
        # NOTE(review): assumes get_tensor_bit yields one bit per key and that the t_* tensors returned by
        # gen_dpf_cw broadcast correctly against this column — TODO confirm.
        cond = (alpha.get_tensor_bit(alpha.bit_len - 1 - i) == 0).view(-1, 1)

        l_tensors = [s_l_0, s_l_1, t_l_0, t_l_1]
        r_tensors = [s_r_0, s_r_1, t_r_0, t_r_1]

        # "keep" = the child on alpha's path; "lose" = its off-path sibling (selected per key).
        keep_tensors = [torch.where(cond, l, r) for l, r in zip(l_tensors, r_tensors)]
        lose_tensors = [torch.where(cond, r, l) for l, r in zip(l_tensors, r_tensors)]

        s_keep_0, s_keep_1, t_keep_0, t_keep_1 = keep_tensors
        s_lose_0, s_lose_1, t_lose_0, t_lose_1 = lose_tensors

        # Seed correction word: XOR of the two parties' off-path seeds.
        s_cw = s_lose_0 ^ s_lose_1

        # Control-bit correction words. ~cond equals the current alpha bit a_i, so:
        #   t_l_cw = t_l_0 ^ t_l_1 ^ a_i ^ 1
        #   t_r_cw = t_r_0 ^ t_r_1 ^ a_i
        t_l_cw = t_l_0 ^ t_l_1 ^ ~cond ^ 1
        t_r_cw = t_r_0 ^ t_r_1 ^ ~cond

        cw = CW(s_cw=s_cw, t_cw_l=t_l_cw, t_cw_r=t_r_cw, lmd=LAMBDA)

        # The same correction-word object for this level goes into both keys.
        k0.cw_list.append(cw)
        k1.cw_list.append(cw)

        # Control-bit correction word for the child that stays on the path.
        t_keep_cw = torch.where(cond, t_l_cw, t_r_cw)

        # Move to the on-path child. The correction is multiplied by the previous control bit, so it is applied
        # only where that bit is 1 (presumably t_last is always 0/1).
        s_last_0 = s_keep_0 ^ (t_last_0 * s_cw)
        s_last_1 = s_keep_1 ^ (t_last_1 * s_cw)

        t_last_0 = t_keep_0 ^ (t_last_0 * t_keep_cw)
        t_last_1 = t_keep_1 ^ (t_last_1 * t_keep_cw)

    # TODO: Hash function
    # Each party's proof pi_b comes from a PRG seeded with (final seed + alpha), standing in for the hash.
    # NOTE(review): addition is not an injective encoding of (seed, alpha); the concatenation variant is left
    # commented out below.
    # prg.set_seeds(torch.cat((s_last_0, alpha.tensor.unsqueeze(1)), dim=1))
    prg.set_seeds(s_last_0 + alpha.tensor.unsqueeze(1))
    pi_0 = prg.bit_random_tensor(4 * LAMBDA)
    # prg.set_seeds(torch.cat((s_last_1, alpha.tensor.unsqueeze(1)), dim=1))
    prg.set_seeds(s_last_1 + alpha.tensor.unsqueeze(1))
    pi_1 = prg.bit_random_tensor(4 * LAMBDA)

    # Aliases for the final (level n+1) seeds of each party.
    s_0_n_add_1 = s_last_0
    s_1_n_add_1 = s_last_1

    # t_0_n_add_1 = s_0_n_add_1 & 1
    # t_1_n_add_1 = s_1_n_add_1 & 1
    # Verification value: XOR of both parties' proofs, stored identically in both keys.
    cs = pi_0 ^ pi_1
    k0.cs = k1.cs = cs
    # Output correction word: (-1)^{t_1} * (beta - Convert(s_0) + Convert(s_1)), using party 1's final control bit.
    k0.ocw = k1.ocw = pow(-1, t_last_1) * (
            beta.tensor - convert_tensor(s_0_n_add_1) + convert_tensor(s_1_n_add_1))

    return k0, k1
